{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "1a78e1d3-662b-4700-8e5f-8253be012651",
   "metadata": {},
   "source": [
    "This notebook provides a short example of how to evaluate your trained model by visualizing its predictions and their power spectra. It continues from the previous notebook `03_`, in which a working model was trained.\n",
    "\n",
    "The complete evaluation script, `eval_direct.py`, can be found in the root directory of the repository."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0c94eada-28b8-4aab-a496-0b9fa8b3d8e1",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3d38204f-8434-435e-bc24-b1fefb0bf81d",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import yaml\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from collections import defaultdict\n",
    "from pathlib import Path\n",
    "from tqdm import tqdm\n",
    "import pickle\n",
    "import copy\n",
    "import xarray as xr\n",
    "import torch\n",
    "from torch.utils.data import DataLoader, TensorDataset\n",
    "import lightning.pytorch as pl\n",
    "from lightning.pytorch import loggers as pl_loggers\n",
    "from lightning.pytorch.callbacks import ModelCheckpoint\n",
    "pl.seed_everything(42)\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "# Configure seaborn in a single call: `set_theme` re-applies its default 'darkgrid'\n",
    "# style, so a separate `set_style` call placed before it would be discarded.\n",
    "sns.set_theme(context='paper', style='whitegrid', rc={'axes.grid': False})\n",
    "\n",
    "\n",
    "# Adjusting global plotting parameters (after the theme, which resets font sizes)\n",
    "plt.rcParams['font.size'] = 26        # Adjusts the main font size\n",
    "plt.rcParams['axes.labelsize'] = 26   # Font size of x and y labels\n",
    "plt.rcParams['xtick.labelsize'] = 26  # Font size of numbers on x-axis\n",
    "plt.rcParams['ytick.labelsize'] = 26  # Font size of numbers on y-axis\n",
    "plt.rcParams['legend.fontsize'] = 26  # Font size of legend\n",
    "\n",
    "# Make the repository root importable so the local `chaosbench` package resolves\n",
    "import sys\n",
    "sys.path.append('..')\n",
    "\n",
    "from chaosbench import dataset, config, utils, criterion\n",
    "from chaosbench.models import model, mlp, cnn, ae\n",
    "\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9957c83f-77b7-4d1a-a7d8-890c6b9e2986",
   "metadata": {
    "tags": []
   },
   "source": [
    "## Prediction Visualization"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9412bc2f-feb6-4fcc-9240-dc203bb16617",
   "metadata": {},
   "source": [
    "For ClimaX..."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4042260c-c619-4911-b7ec-a8faf3379912",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "#################### CHANGE THIS ####################\n",
    "date_idx = 0\n",
    "param = 'q'\n",
    "level = 700\n",
    "model_name = 'climax'\n",
    "task_num = 2\n",
    "\n",
    "plot_idx = [1, config.N_STEPS - 1]\n",
    "######################################################\n",
    "\n",
    "# Compare stored predictions with the observed fields for one variable/level:\n",
    "# rows are truth, prediction and residual; columns are the two selected lead times.\n",
    "\n",
    "## Dataset: Prediction\n",
    "### List filenames related to the model and task number\n",
    "log_dir = Path('../logs') / model_name\n",
    "preds_filepath = log_dir / 'preds' / f'task{task_num}'\n",
    "preds_files = sorted(preds_filepath.glob('*.pkl'))\n",
    "if not preds_files:\n",
    "    raise FileNotFoundError(f'No prediction files (*.pkl) found in {preds_filepath}')\n",
    "\n",
    "### Load the actual predictions (only the first and last file; titled t = 1 and t = 44 below)\n",
    "# NOTE: pickle can execute arbitrary code on load; only open files you generated yourself.\n",
    "all_preds = list()\n",
    "for file_path in (preds_files[0], preds_files[-1]):\n",
    "    with open(file_path, 'rb') as file:\n",
    "        pred_record = pickle.load(file)\n",
    "    pred_field = pred_record['pred'][f'{param}_{level}']\n",
    "    all_preds.append(pred_field[date_idx])\n",
    "\n",
    "## Dataset: Label\n",
    "param_level_idx = utils.get_param_level_idx(param, level)\n",
    "output_dataset = dataset.S2SObsDataset(years=[2022], n_step=config.N_STEPS)\n",
    "_, output_x, output_y = output_dataset[date_idx]\n",
    "\n",
    "# NOTE(review): truth steps come from `plot_idx` while predictions come from the first/last\n",
    "# file -- confirm that both selections refer to the same lead times.\n",
    "all_truth = [output_y[idx][param_level_idx].detach().cpu().numpy() for idx in plot_idx]\n",
    "\n",
    "\n",
    "# Plotting\n",
    "all_truth = np.array(all_truth)\n",
    "all_preds = np.array(all_preds)\n",
    "# Standardize the predictions with their own mean/std before comparing them to the truth\n",
    "all_preds = (all_preds - all_preds.mean()) / all_preds.std()\n",
    "\n",
    "f, ax = plt.subplots(3, len(plot_idx), figsize=(8, 3 * len(plot_idx)))\n",
    "\n",
    "for time_idx in range(len(plot_idx)):\n",
    "    # (field, symmetric colour limit) for the truth, prediction and residual rows\n",
    "    panels = [\n",
    "        (all_truth[time_idx], 2),\n",
    "        (all_preds[time_idx], 2),\n",
    "        (all_preds[time_idx] - all_truth[time_idx], 1),\n",
    "    ]\n",
    "    for row_idx, (field, vlim) in enumerate(panels):\n",
    "        im = ax[row_idx, time_idx].imshow(field, cmap='RdBu_r', vmin=-vlim, vmax=vlim)\n",
    "        ax[row_idx, time_idx].axis('off')\n",
    "        f.colorbar(im, ax=ax[row_idx, time_idx], shrink=0.8)\n",
    "\n",
    "# Adding titles for each row\n",
    "titles = [r'$Truth$', r'$Prediction$', r'$Residual$']\n",
    "for idx, title in enumerate(titles):\n",
    "    f.text(-0.01, 0.75 - idx*0.28, title, va='center', ha='left', fontsize=12, rotation=90)\n",
    "\n",
    "ax[0,0].set_title(f'Task {task_num} (Direct)\\n norm-{param}{level}\\n t = 1', fontsize=12)\n",
    "ax[0,1].set_title(f'Task {task_num} (Direct)\\n norm-{param}{level}\\n t = 44', fontsize=12)\n",
    "plt.tight_layout()\n",
    "plt.show()\n",
    "f.savefig(f'../docs/preds_{model_name}_{param}{level}_direct_Task {task_num}.pdf', dpi=200, bbox_inches='tight');\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "09c6d261-66f9-451d-ae64-44f7d8e0abaf",
   "metadata": {},
   "source": [
    "For UNet..."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "75fba551-59f2-46d2-b451-932526735993",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "#################### CHANGE THIS ####################\n",
    "date_idx = 0\n",
    "param = 'q'\n",
    "level = 700\n",
    "model_name = 'unet_s2s'\n",
    "task_num = 1\n",
    "\n",
    "DELTA_T = np.array([1, 44])\n",
    "######################################################\n",
    "\n",
    "# Run one trained checkpoint per lead time in DELTA_T on a single sample and compare the\n",
    "# predicted field with the observed one (rows: truth, prediction, residual).\n",
    "\n",
    "## Dataset\n",
    "param_level_idx = utils.get_param_level_idx(param, level)\n",
    "input_dataset = dataset.S2SObsDataset(years=[2022], n_step=config.N_STEPS)\n",
    "output_dataset = dataset.S2SObsDataset(years=[2022], n_step=config.N_STEPS)\n",
    "\n",
    "\n",
    "## Load config filepath which consists of all the definition needed to fit/eval a model\n",
    "log_dir = Path('../logs') / model_name\n",
    "model_config_filepath = Path(f'../chaosbench/configs/{model_name}.yaml')\n",
    "\n",
    "with open(model_config_filepath, 'r') as config_filepath:\n",
    "    hyperparams = yaml.load(config_filepath, Loader=yaml.FullLoader)\n",
    "\n",
    "model_args = hyperparams['model_args']\n",
    "data_args = hyperparams['data_args']\n",
    "\n",
    "## Checkpointing\n",
    "# One `lightning_logs` version per entry of DELTA_T\n",
    "version_nums = [0, 12] if task_num == 1 else [2, 21]\n",
    "assert len(version_nums) == len(DELTA_T)\n",
    "\n",
    "baselines = list()\n",
    "for version_num in version_nums:\n",
    "    ckpt_dir = log_dir / f'lightning_logs/version_{version_num}/checkpoints/'\n",
    "    ckpt_files = sorted(ckpt_dir.glob('*.ckpt'))\n",
    "    if not ckpt_files:\n",
    "        raise FileNotFoundError(f'No checkpoint (*.ckpt) found in {ckpt_dir}')\n",
    "    ckpt_filepath = ckpt_files[0]\n",
    "\n",
    "    # `load_from_checkpoint` is a classmethod: call it on the class (recent Lightning releases\n",
    "    # refuse to run it on an instance) and hand over the init arguments explicitly.\n",
    "    baseline = model.S2SBenchmarkModel.load_from_checkpoint(\n",
    "        ckpt_filepath, model_args=model_args, data_args=data_args\n",
    "    )\n",
    "    # Inference only: keep the model on the same device as the input and in eval mode\n",
    "    baselines.append(baseline.to(device).eval())\n",
    "\n",
    "\n",
    "all_preds = list()\n",
    "all_truth = list()\n",
    "\n",
    "with torch.no_grad():\n",
    "\n",
    "    timestamp, input_x, input_y = input_dataset[date_idx]\n",
    "    _, output_x, output_y = output_dataset[date_idx]\n",
    "\n",
    "    curr_x = input_x.unsqueeze(0).to(device)\n",
    "\n",
    "    for step_idx, delta in enumerate(DELTA_T):\n",
    "        preds = baselines[step_idx](curr_x)\n",
    "        # Truth at the lead time of this checkpoint, not at the loop position.\n",
    "        # Assumes output_y[k] is the (k + 1)-day-ahead target -- TODO confirm against S2SObsDataset\n",
    "        curr_y = output_y.unsqueeze(0)[:, int(delta) - 1]\n",
    "        all_preds.append(preds[0][param_level_idx].detach().cpu().numpy())\n",
    "        all_truth.append(curr_y[0][param_level_idx].detach().cpu().numpy())\n",
    "\n",
    "\n",
    "# Plotting\n",
    "all_preds = np.array(all_preds)\n",
    "all_truth = np.array(all_truth)\n",
    "\n",
    "f, ax = plt.subplots(3, len(DELTA_T), figsize=(8, 3 * len(DELTA_T)))\n",
    "\n",
    "for time_idx in range(len(DELTA_T)):\n",
    "    # (field, symmetric colour limit) for the truth, prediction and residual rows\n",
    "    panels = [\n",
    "        (all_truth[time_idx], 2),\n",
    "        (all_preds[time_idx], 2),\n",
    "        (all_preds[time_idx] - all_truth[time_idx], 1),\n",
    "    ]\n",
    "    for row_idx, (field, vlim) in enumerate(panels):\n",
    "        im = ax[row_idx, time_idx].imshow(field, cmap='RdBu_r', vmin=-vlim, vmax=vlim)\n",
    "        ax[row_idx, time_idx].axis('off')\n",
    "        f.colorbar(im, ax=ax[row_idx, time_idx], shrink=0.8)\n",
    "\n",
    "# Adding titles for each row\n",
    "titles = [r'$Truth$', r'$Prediction$', r'$Residual$']\n",
    "for idx, title in enumerate(titles):\n",
    "    f.text(-0.01, 0.75 - idx*0.28, title, va='center', ha='left', fontsize=12, rotation=90)\n",
    "\n",
    "ax[0,0].set_title(f'Task {task_num} (Direct)\\n norm-{param}{level}\\n t = 1', fontsize=12)\n",
    "ax[0,1].set_title(f'Task {task_num} (Direct)\\n norm-{param}{level}\\n t = 44', fontsize=12)\n",
    "plt.tight_layout()\n",
    "plt.show()\n",
    "f.savefig(f'../docs/preds_{model_name}_{param}{level}_direct_Task {task_num}.pdf', dpi=200, bbox_inches='tight');\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "77c9134a-ccbc-4c9a-a1ab-51b9b23637e4",
   "metadata": {},
   "source": [
    "## Power Spectrum"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ae602625-fd23-4c2e-bba9-9ccf8c13a840",
   "metadata": {},
   "source": [
    "### Autoregressive approach - SoTA"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "aea60639-6660-42ed-92ad-fd851390a1d6",
   "metadata": {},
   "source": [
    "For Panguweather..."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "03c512b8-d8dd-4b56-ba89-b603536a7d2d",
   "metadata": {},
   "outputs": [],
   "source": [
    "#################### CHANGE THIS ####################\n",
    "model_name = 'panguweather'\n",
    "date_idx = 0\n",
    "n_steps = [0, 43]\n",
    "param_levels = [['t', 850], ['z', 500], ['q', 700]]\n",
    "######################################################\n",
    "\n",
    "# Radially averaged power spectra of the forecast fields, one panel per variable/level.\n",
    "# `all_Sk` collects the spectra of every step; only the steps in `n_steps` are drawn.\n",
    "all_Sk = dict()\n",
    "f, ax = plt.subplots(1, len(param_levels), figsize=(10*len(param_levels), 8))\n",
    "\n",
    "## Preprocessing\n",
    "log_dir = Path('../logs') / model_name\n",
    "preds_dataset = xr.open_dataset(\n",
    "    log_dir / f'{model_name}.grib',\n",
    "    backend_kwargs={'filter_by_keys': {'typeOfLevel': 'isobaricInhPa'}}\n",
    ")\n",
    "\n",
    "# Average blocks of 4 steps and 6x6 grid points, then resample latitude onto 121 points\n",
    "preds_dataset = preds_dataset.coarsen(step=4, latitude=6, longitude=6, boundary='trim').mean()\n",
    "preds_dataset['z'] = preds_dataset['z'] / 9.8 # to gpm conversion\n",
    "preds_dataset = preds_dataset.interp(latitude=np.linspace(\n",
    "                                        preds_dataset.latitude.values.max(),\n",
    "                                        preds_dataset.latitude.values.min(), 121))\n",
    "\n",
    "## Dataset: Label (the same sample serves every variable, so load it once)\n",
    "output_dataset = dataset.S2SObsDataset(years=[2016], n_step=config.N_STEPS-1)\n",
    "_, _, output_y = output_dataset[date_idx]\n",
    "\n",
    "for i, param_level in enumerate(param_levels):\n",
    "    param, level = param_level\n",
    "    param_level_idx = utils.get_param_level_idx(param, level)\n",
    "\n",
    "    print(f'Processing {model_name}: {param}-{level}')\n",
    "\n",
    "    ## Dataset: Preds (standardized with their own mean/std)\n",
    "    all_preds = preds_dataset[param].sel(isobaricInhPa=int(level)).values\n",
    "    all_preds = (all_preds - all_preds.mean()) / all_preds.std()\n",
    "\n",
    "    all_truth = output_y[:, param_level_idx].detach().cpu().numpy()\n",
    "\n",
    "    # Compute power spectrum\n",
    "    curr_pred_Sk, curr_truth_Sk = list(), list()\n",
    "\n",
    "    for step_idx in range(all_truth.shape[0]):\n",
    "        pred_t, truth_t = all_preds[step_idx], all_truth[step_idx]\n",
    "        pred_power_t = np.abs(np.fft.fft2(pred_t))**2\n",
    "        truth_power_t = np.abs(np.fft.fft2(truth_t))**2\n",
    "\n",
    "        # Radial wavenumber of every FFT coefficient\n",
    "        ny, nx = pred_t.shape\n",
    "        kx = np.fft.fftfreq(nx) * nx\n",
    "        ky = np.fft.fftfreq(ny) * ny\n",
    "        kx, ky = np.meshgrid(kx, ky)\n",
    "        k = np.sqrt(kx**2 + ky**2)\n",
    "\n",
    "        # Mean power inside unit-width rings centred on k = 1, 2, ...\n",
    "        k_bins = np.arange(0.5, np.max(k), 1)\n",
    "        k_bin_centers = 0.5 * (k_bins[:-1] + k_bins[1:])\n",
    "        k_counts = np.histogram(k, bins=k_bins)[0]\n",
    "        pred_Sk = np.histogram(k, bins=k_bins, weights=pred_power_t)[0] / k_counts\n",
    "        truth_Sk = np.histogram(k, bins=k_bins, weights=truth_power_t)[0] / k_counts\n",
    "\n",
    "        curr_pred_Sk.append(pred_Sk)\n",
    "        curr_truth_Sk.append(truth_Sk)\n",
    "\n",
    "        # Plot power spectrum\n",
    "        if step_idx in n_steps:\n",
    "            ax[i].set_title(f'{param}-{level}', fontsize=40)\n",
    "            # Plot against the bin centres: without explicit x values the first bin sits at\n",
    "            # x = 0, which a log axis cannot show, and every wavenumber is shifted by one.\n",
    "            ax[i].loglog(k_bin_centers, pred_Sk, label=f'{step_idx + 1}-day ahead', linewidth=3)\n",
    "            ax[i].set_xlabel('Wavenumber, k')\n",
    "            ax[i].set_ylabel(f'Power, S(k)')\n",
    "            ax[i].set_ylim([10**0, 10**7])\n",
    "            ax[i].legend()\n",
    "\n",
    "    all_Sk[f'{param}-{level}'] = np.array(curr_pred_Sk)\n",
    "    # NOTE: shared key, overwritten on every pass -- ends up holding the last variable only\n",
    "    all_Sk['truth'] = np.array(curr_truth_Sk)\n",
    "\n",
    "plt.show()\n",
    "f.savefig(f'../docs/specdiv_{model_name}_sota.pdf', dpi=200, bbox_inches='tight');"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "abd666c1-e033-48db-9b35-d3dc7223eec1",
   "metadata": {},
   "source": [
    "For Graphcast"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ec28363e-b20a-4525-b737-2e6b5605805f",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "#################### CHANGE THIS ####################\n",
    "model_name = 'graphcast'\n",
    "date_idx = 0\n",
    "n_steps = [0, 43]\n",
    "param_levels = [['t', 850], ['z', 500], ['q', 700]]\n",
    "######################################################\n",
    "\n",
    "# Radially averaged power spectra of the forecast fields, one panel per variable/level.\n",
    "# `all_Sk` collects the spectra of every step; only the steps in `n_steps` are drawn.\n",
    "all_Sk = dict()\n",
    "f, ax = plt.subplots(1, len(param_levels), figsize=(10*len(param_levels), 8))\n",
    "\n",
    "## Preprocessing\n",
    "log_dir = Path('../logs') / model_name\n",
    "preds_dataset = xr.open_dataset(\n",
    "    log_dir / f'{model_name}.grib',\n",
    "    backend_kwargs={'filter_by_keys': {'typeOfLevel': 'isobaricInhPa'}}\n",
    ")\n",
    "\n",
    "# Average blocks of 4 steps and 6x6 grid points, then resample latitude onto 121 points\n",
    "preds_dataset = preds_dataset.coarsen(step=4, latitude=6, longitude=6, boundary='trim').mean()\n",
    "preds_dataset['z'] = preds_dataset['z'] / 9.8 # to gpm conversion\n",
    "preds_dataset = preds_dataset.interp(latitude=np.linspace(\n",
    "                                        preds_dataset.latitude.values.max(),\n",
    "                                        preds_dataset.latitude.values.min(), 121))\n",
    "\n",
    "## Dataset: Label (the same sample serves every variable, so load it once)\n",
    "output_dataset = dataset.S2SObsDataset(years=[2016], n_step=config.N_STEPS-1)\n",
    "_, _, output_y = output_dataset[date_idx]\n",
    "\n",
    "for i, param_level in enumerate(param_levels):\n",
    "    param, level = param_level\n",
    "    param_level_idx = utils.get_param_level_idx(param, level)\n",
    "\n",
    "    print(f'Processing {model_name}: {param}-{level}')\n",
    "\n",
    "    ## Dataset: Preds (standardized with their own mean/std)\n",
    "    all_preds = preds_dataset[param].sel(isobaricInhPa=int(level)).values\n",
    "    all_preds = (all_preds - all_preds.mean()) / all_preds.std()\n",
    "\n",
    "    all_truth = output_y[:, param_level_idx].detach().cpu().numpy()\n",
    "\n",
    "    # Compute power spectrum\n",
    "    curr_pred_Sk, curr_truth_Sk = list(), list()\n",
    "\n",
    "    for step_idx in range(all_truth.shape[0]):\n",
    "        pred_t, truth_t = all_preds[step_idx], all_truth[step_idx]\n",
    "        pred_power_t = np.abs(np.fft.fft2(pred_t))**2\n",
    "        truth_power_t = np.abs(np.fft.fft2(truth_t))**2\n",
    "\n",
    "        # Radial wavenumber of every FFT coefficient\n",
    "        ny, nx = pred_t.shape\n",
    "        kx = np.fft.fftfreq(nx) * nx\n",
    "        ky = np.fft.fftfreq(ny) * ny\n",
    "        kx, ky = np.meshgrid(kx, ky)\n",
    "        k = np.sqrt(kx**2 + ky**2)\n",
    "\n",
    "        # Mean power inside unit-width rings centred on k = 1, 2, ...\n",
    "        k_bins = np.arange(0.5, np.max(k), 1)\n",
    "        k_bin_centers = 0.5 * (k_bins[:-1] + k_bins[1:])\n",
    "        k_counts = np.histogram(k, bins=k_bins)[0]\n",
    "        pred_Sk = np.histogram(k, bins=k_bins, weights=pred_power_t)[0] / k_counts\n",
    "        truth_Sk = np.histogram(k, bins=k_bins, weights=truth_power_t)[0] / k_counts\n",
    "\n",
    "        curr_pred_Sk.append(pred_Sk)\n",
    "        curr_truth_Sk.append(truth_Sk)\n",
    "\n",
    "        # Plot power spectrum\n",
    "        if step_idx in n_steps:\n",
    "            ax[i].set_title(f'{param}-{level}', fontsize=40)\n",
    "            # Plot against the bin centres: without explicit x values the first bin sits at\n",
    "            # x = 0, which a log axis cannot show, and every wavenumber is shifted by one.\n",
    "            ax[i].loglog(k_bin_centers, pred_Sk, label=f'{step_idx + 1}-day ahead', linewidth=3)\n",
    "            ax[i].set_xlabel('Wavenumber, k')\n",
    "            ax[i].set_ylabel(f'Power, S(k)')\n",
    "            ax[i].set_ylim([10**0, 10**7])\n",
    "            ax[i].legend()\n",
    "\n",
    "    all_Sk[f'{param}-{level}'] = np.array(curr_pred_Sk)\n",
    "    # NOTE: shared key, overwritten on every pass -- ends up holding the last variable only\n",
    "    all_Sk['truth'] = np.array(curr_truth_Sk)\n",
    "\n",
    "plt.show()\n",
    "f.savefig(f'../docs/specdiv_{model_name}_sota.pdf', dpi=200, bbox_inches='tight');"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1ddf2d02-36a9-4f68-9866-15b1099ece13",
   "metadata": {},
   "source": [
    "For Fourcastnetv2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0364b510-bb34-49f7-9675-2e091a30d0db",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "#################### CHANGE THIS ####################\n",
    "model_name = 'fourcastnetv2'\n",
    "date_idx = 0\n",
    "n_steps = [0, 43]\n",
    "param_levels = [['t', 850], ['z', 500]]\n",
    "######################################################\n",
    "\n",
    "# Radially averaged power spectra of the forecast fields, one panel per variable/level.\n",
    "# `all_Sk` collects the spectra of every step; only the steps in `n_steps` are drawn.\n",
    "all_Sk = dict()\n",
    "f, ax = plt.subplots(1, len(param_levels), figsize=(10*len(param_levels), 8))\n",
    "\n",
    "## Preprocessing\n",
    "log_dir = Path('../logs') / model_name\n",
    "preds_dataset = xr.open_dataset(\n",
    "    log_dir / f'{model_name}-small.grib',\n",
    "    backend_kwargs={'filter_by_keys': {'typeOfLevel': 'isobaricInhPa'}}\n",
    ")\n",
    "\n",
    "# Average blocks of 4 steps and 6x6 grid points, then resample latitude onto 121 points\n",
    "preds_dataset = preds_dataset.coarsen(step=4, latitude=6, longitude=6, boundary='trim').mean()\n",
    "preds_dataset['z'] = preds_dataset['z'] / 9.8 # to gpm conversion\n",
    "preds_dataset = preds_dataset.interp(latitude=np.linspace(\n",
    "                                        preds_dataset.latitude.values.max(),\n",
    "                                        preds_dataset.latitude.values.min(), 121))\n",
    "\n",
    "## Dataset: Label (the same sample serves every variable, so load it once)\n",
    "output_dataset = dataset.S2SObsDataset(years=[2016], n_step=config.N_STEPS-1)\n",
    "_, _, output_y = output_dataset[date_idx]\n",
    "\n",
    "for i, param_level in enumerate(param_levels):\n",
    "    param, level = param_level\n",
    "    param_level_idx = utils.get_param_level_idx(param, level)\n",
    "\n",
    "    print(f'Processing {model_name}: {param}-{level}')\n",
    "\n",
    "    ## Dataset: Preds (standardized with their own mean/std)\n",
    "    all_preds = preds_dataset[param].sel(isobaricInhPa=int(level)).values\n",
    "    all_preds = (all_preds - all_preds.mean()) / all_preds.std()\n",
    "\n",
    "    all_truth = output_y[:, param_level_idx].detach().cpu().numpy()\n",
    "\n",
    "    # Compute power spectrum\n",
    "    curr_pred_Sk, curr_truth_Sk = list(), list()\n",
    "\n",
    "    for step_idx in range(all_truth.shape[0]):\n",
    "        pred_t, truth_t = all_preds[step_idx], all_truth[step_idx]\n",
    "        pred_power_t = np.abs(np.fft.fft2(pred_t))**2\n",
    "        truth_power_t = np.abs(np.fft.fft2(truth_t))**2\n",
    "\n",
    "        # Radial wavenumber of every FFT coefficient\n",
    "        ny, nx = pred_t.shape\n",
    "        kx = np.fft.fftfreq(nx) * nx\n",
    "        ky = np.fft.fftfreq(ny) * ny\n",
    "        kx, ky = np.meshgrid(kx, ky)\n",
    "        k = np.sqrt(kx**2 + ky**2)\n",
    "\n",
    "        # Mean power inside unit-width rings centred on k = 1, 2, ...\n",
    "        k_bins = np.arange(0.5, np.max(k), 1)\n",
    "        k_bin_centers = 0.5 * (k_bins[:-1] + k_bins[1:])\n",
    "        k_counts = np.histogram(k, bins=k_bins)[0]\n",
    "        pred_Sk = np.histogram(k, bins=k_bins, weights=pred_power_t)[0] / k_counts\n",
    "        truth_Sk = np.histogram(k, bins=k_bins, weights=truth_power_t)[0] / k_counts\n",
    "\n",
    "        curr_pred_Sk.append(pred_Sk)\n",
    "        curr_truth_Sk.append(truth_Sk)\n",
    "\n",
    "        # Plot power spectrum\n",
    "        if step_idx in n_steps:\n",
    "            ax[i].set_title(f'{param}-{level}', fontsize=40)\n",
    "            # Plot against the bin centres: without explicit x values the first bin sits at\n",
    "            # x = 0, which a log axis cannot show, and every wavenumber is shifted by one.\n",
    "            ax[i].loglog(k_bin_centers, pred_Sk, label=f'{step_idx + 1}-day ahead', linewidth=3)\n",
    "            ax[i].set_xlabel('Wavenumber, k')\n",
    "            ax[i].set_ylabel(f'Power, S(k)')\n",
    "            ax[i].set_ylim([10**0, 10**7])\n",
    "            ax[i].legend()\n",
    "\n",
    "    all_Sk[f'{param}-{level}'] = np.array(curr_pred_Sk)\n",
    "    # NOTE: shared key, overwritten on every pass -- ends up holding the last variable only\n",
    "    all_Sk['truth'] = np.array(curr_truth_Sk)\n",
    "\n",
    "plt.show()\n",
    "f.savefig(f'../docs/specdiv_{model_name}_sota.pdf', dpi=200, bbox_inches='tight');"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6b6d9363-0712-4fa8-b43b-9f5ad07df1c8",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Plot full power spectra\n",
    "# Reads `all_Sk` (keyed '{param}-{level}') and `param_levels` left behind by one of the\n",
    "# autoregressive cells above, so run that cell first.\n",
    "import matplotlib as mpl\n",
    "mpl.rcParams.update(mpl.rcParamsDefault)  # drop the large font sizes set at the top of the notebook\n",
    "\n",
    "def set_log_ticks_10_power(axis, num_ticks=5):\n",
    "    \"\"\"Label an axis holding log10-transformed data with powers of ten.\n",
    "\n",
    "    Ticks are placed on the integer exponents inside the axis' data interval (thinned to\n",
    "    at most `num_ticks`), so each label `10^n` sits exactly at position `n`. If the interval\n",
    "    contains no integer, `num_ticks` evenly spaced ticks are used and labelled with their\n",
    "    fractional exponent instead.\n",
    "    \"\"\"\n",
    "    lo, hi = axis.get_data_interval()\n",
    "    first, last = int(np.ceil(lo)), int(np.floor(hi))\n",
    "\n",
    "    if last >= first:\n",
    "        stride = max(1, int(np.ceil((last - first + 1) / num_ticks)))\n",
    "        ticks = np.arange(first, last + 1, stride)\n",
    "        labels = [f'$10^{{{int(tick)}}}$' for tick in ticks]\n",
    "    else:\n",
    "        ticks = np.linspace(lo, hi, num=num_ticks)\n",
    "        labels = [f'$10^{{{tick:.1f}}}$' for tick in ticks]\n",
    "\n",
    "    axis.set_ticks(ticks)\n",
    "    axis.set_ticklabels(labels)\n",
    "\n",
    "\n",
    "eps = 1e-50 # to avoid log(0)\n",
    "f = plt.figure(figsize=(16, 10))\n",
    "\n",
    "for param_id, param_level in enumerate(param_levels):\n",
    "\n",
    "    curr_Sk = all_Sk[f'{param_level[0]}-{param_level[1]}']\n",
    "\n",
    "    # Plot the 3D contour plot (subplot positions are 1-based)\n",
    "    ax = f.add_subplot(1, len(param_levels), param_id + 1, projection='3d')\n",
    "    Wavenumber, Timestep = np.meshgrid(np.arange(1, curr_Sk.shape[1] + 1), np.arange(1, curr_Sk.shape[0] + 1))\n",
    "    contour = ax.contour3D(np.log10(Wavenumber + eps), Timestep, np.log10(curr_Sk + eps), 100, cmap='inferno_r')\n",
    "\n",
    "    ax.set_xlabel('Wavenumber, k')\n",
    "    ax.set_ylabel('Number of days ahead')\n",
    "    ax.set_zlabel(r'Power, S(k)', labelpad=0.1)\n",
    "\n",
    "    # x and z carry log10 values, so relabel their ticks as powers of ten\n",
    "    set_log_ticks_10_power(ax.xaxis)\n",
    "    set_log_ticks_10_power(ax.zaxis)\n",
    "\n",
    "    ax.set_title(f'{param_level[0]}{param_level[1]}', fontsize=12)\n",
    "\n",
    "plt.show()\n",
    "# f.savefig(f'../docs/3d_{model_name}_sota.pdf', dpi=200, bbox_inches='tight');"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "295f5880-d49f-408e-9937-8b9a2f57031c",
   "metadata": {},
   "source": [
    "### Direct approach "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4cbce70c-be08-4e9a-b75b-a7a5ace29428",
   "metadata": {},
   "source": [
    "For ClimaX..."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1ee37c51-3465-47d2-a3e3-3aedd31b8c5d",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "#################### CHANGE THIS ####################\n",
    "date_idx = 0\n",
    "delta_t = np.array([1, 5, 10, 15, 20, 25, 30, 35, 40, 44])\n",
    "n_steps = np.arange(len(delta_t))\n",
    "model_name = 'climax'\n",
    "task_num = 1\n",
    "model_spec = f'Task_{task_num}/Direct'\n",
    "######################################################\n",
    "\n",
    "# Radially averaged power spectra of the stored direct predictions, one panel per\n",
    "# variable/level and one curve per lead time in `delta_t`.\n",
    "param_levels = [['t', 850], ['z', 500], ['q', 700]]\n",
    "all_Sk = dict()\n",
    "\n",
    "## Dataset: Label (the same sample serves every variable, so load it once)\n",
    "output_dataset = dataset.S2SObsDataset(years=[2022], n_step=config.N_STEPS)\n",
    "_, output_x, output_y = output_dataset[date_idx]\n",
    "\n",
    "## Dataset: Prediction\n",
    "### List filenames related to the model and task number\n",
    "log_dir = Path('../logs') / model_name\n",
    "preds_filepath = log_dir / 'preds' / f'task{task_num}'\n",
    "preds_files = sorted(preds_filepath.glob('*.pkl'))\n",
    "if len(preds_files) < len(delta_t):\n",
    "    raise FileNotFoundError(\n",
    "        f'Expected at least {len(delta_t)} prediction files (*.pkl) in {preds_filepath}, found {len(preds_files)}'\n",
    "    )\n",
    "\n",
    "## For plotting\n",
    "f, ax = plt.subplots(1, len(param_levels), figsize=(10*len(param_levels), 8))\n",
    "\n",
    "for i, param_level in enumerate(param_levels):\n",
    "    param, level = param_level\n",
    "    param_level_idx = utils.get_param_level_idx(param, level)\n",
    "\n",
    "    print(f'Processing {model_spec}')\n",
    "\n",
    "    ### Load the predictions: the first len(delta_t) files in sorted filename order\n",
    "    # NOTE(review): presumably the sorted files line up with `delta_t` -- verify the naming scheme.\n",
    "    # NOTE: pickle can execute arbitrary code on load; only open files you generated yourself.\n",
    "    all_preds = list()\n",
    "    for step_idx in n_steps:\n",
    "        with open(preds_files[step_idx], 'rb') as file:\n",
    "            pred_record = pickle.load(file)\n",
    "        all_preds.append(pred_record['pred'][f'{param}_{level}'][date_idx])\n",
    "\n",
    "    # Truth at the lead times in `delta_t`, not at the loop positions.\n",
    "    # Assumes output_y[k] is the (k + 1)-day-ahead target -- TODO confirm against S2SObsDataset\n",
    "    all_truth = [output_y[int(delta) - 1][param_level_idx].detach().cpu().numpy() for delta in delta_t]\n",
    "\n",
    "    ## Post-process truth and predictions (have to normalize this since it comes in real values)\n",
    "    all_truth = np.array(all_truth)\n",
    "    all_preds = np.array(all_preds)\n",
    "    all_preds = (all_preds - all_preds.mean()) / all_preds.std()\n",
    "\n",
    "    # Compute power spectrum\n",
    "    curr_pred_Sk, curr_truth_Sk = list(), list()\n",
    "\n",
    "    for step_idx in range(all_preds.shape[0]):\n",
    "        pred_t, truth_t = all_preds[step_idx], all_truth[step_idx]\n",
    "        pred_power_t = np.abs(np.fft.fft2(pred_t))**2\n",
    "        truth_power_t = np.abs(np.fft.fft2(truth_t))**2\n",
    "\n",
    "        # Radial wavenumber of every FFT coefficient\n",
    "        ny, nx = pred_t.shape\n",
    "        kx = np.fft.fftfreq(nx) * nx\n",
    "        ky = np.fft.fftfreq(ny) * ny\n",
    "        kx, ky = np.meshgrid(kx, ky)\n",
    "        k = np.sqrt(kx**2 + ky**2)\n",
    "\n",
    "        # Mean power inside unit-width rings centred on k = 1, 2, ...\n",
    "        k_bins = np.arange(0.5, np.max(k), 1)\n",
    "        k_bin_centers = 0.5 * (k_bins[:-1] + k_bins[1:])\n",
    "        k_counts = np.histogram(k, bins=k_bins)[0]\n",
    "        pred_Sk = np.histogram(k, bins=k_bins, weights=pred_power_t)[0] / k_counts\n",
    "        truth_Sk = np.histogram(k, bins=k_bins, weights=truth_power_t)[0] / k_counts\n",
    "\n",
    "        curr_pred_Sk.append(pred_Sk)\n",
    "        curr_truth_Sk.append(truth_Sk)\n",
    "\n",
    "        # Plot power spectrum\n",
    "        ax[i].set_title(f'{param}-{level}', fontsize=40)\n",
    "        # Plot against the bin centres: without explicit x values the first bin sits at\n",
    "        # x = 0, which a log axis cannot show, and every wavenumber is shifted by one.\n",
    "        ax[i].loglog(k_bin_centers, pred_Sk, label=f'{delta_t[step_idx]}-day ahead', linewidth=3)\n",
    "        ax[i].set_xlabel('Wavenumber, k')\n",
    "        ax[i].set_ylabel(f'Power, S(k)')\n",
    "        ax[i].set_ylim([10**0, 10**7])\n",
    "        ax[i].legend()\n",
    "\n",
    "    all_Sk[f'{model_spec}:{param}-{level}'] = np.array(curr_pred_Sk)\n",
    "    # NOTE: shared key, overwritten on every pass -- ends up holding the last variable only\n",
    "    all_Sk['truth'] = np.array(curr_truth_Sk)\n",
    "\n",
    "plt.show()\n",
    "f.savefig(f'../docs/specdiv_{model_name}_direct_Task {task_num}.pdf', dpi=200, bbox_inches='tight');"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a4dd8ce6-7764-4b85-9b38-b2367347931b",
   "metadata": {},
   "source": [
    "For UNet..."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0cfa814e-980a-49ba-86d5-e4ade059f9bc",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "#################### CHANGE THIS ####################\n",
    "date_idx = 0\n",
    "delta_t = np.array([1, 5, 10, 15, 20, 25, 30, 35, 40, 44])\n",
    "n_steps = np.arange(len(delta_t))\n",
    "model_name = 'unet_s2s'\n",
    "task_num = 1\n",
    "model_spec = f'Task_{task_num}/Direct'\n",
    "######################################################\n",
    "\n",
    "param_levels = [['t', 850], ['z', 500], ['q', 700]]\n",
    "all_Sk = dict()\n",
    "\n",
    "## Initialize datasets once: neither depends on the variable being plotted\n",
    "input_dataset = dataset.S2SObsDataset(years=[2022], n_step=config.N_STEPS)\n",
    "output_dataset = dataset.S2SObsDataset(years=[2022], n_step=config.N_STEPS)\n",
    "\n",
    "## Checkpointing: one trained model per lead time in `delta_t` (enforced by the assert)\n",
    "log_dir = Path('../logs') / model_name\n",
    "version_nums = [0,4,5,6,7,8,9,10,11,12] if task_num == 1 else [2,13,14,15,16,17,18,19,20,21]\n",
    "assert len(version_nums) == len(delta_t)\n",
    "\n",
    "## Load the checkpoints once instead of once per variable\n",
    "baselines = list()\n",
    "for version_num in version_nums:\n",
    "    ckpt_dir = log_dir / f'lightning_logs/version_{version_num}/checkpoints/'\n",
    "    # Sorted so the selected file is deterministic if a directory holds several checkpoints\n",
    "    ckpt_filepaths = sorted(ckpt_dir.glob('*.ckpt'))\n",
    "    assert len(ckpt_filepaths) > 0, f'No checkpoint found in {ckpt_dir}'\n",
    "    # load_from_checkpoint is a classmethod that returns a new model; recent Lightning releases\n",
    "    # raise a TypeError when it is called on an instance, so call it on the class directly.\n",
    "    # The previously constructed instance (and the YAML config used to build it) was discarded anyway.\n",
    "    # NOTE(review): models are used exactly as returned here; confirm whether .eval() is wanted before inference\n",
    "    baselines.append(model.S2SBenchmarkModel.load_from_checkpoint(ckpt_filepaths[0]))\n",
    "\n",
    "## Fetch the evaluated sample once; it is the same for every variable\n",
    "timestamp, input_x, input_y = input_dataset[date_idx]\n",
    "_, output_x, output_y = output_dataset[date_idx]\n",
    "curr_x = input_x.unsqueeze(0).to(device)\n",
    "\n",
    "## For plotting\n",
    "f, ax = plt.subplots(1, len(param_levels), figsize=(10*len(param_levels), 8))\n",
    "ax = np.atleast_1d(ax)  # plt.subplots returns a bare Axes when only one panel is requested\n",
    "\n",
    "for i, param_level in enumerate(param_levels):\n",
    "    param = param_level[0]\n",
    "    level = param_level[1]\n",
    "    param_level_idx = utils.get_param_level_idx(param, level)\n",
    "\n",
    "    print(f'Processing {model_spec}')\n",
    "\n",
    "    all_preds = list()\n",
    "    all_truth = list()\n",
    "\n",
    "    ## Direct forecasts: the model at position step_idx predicts lead time delta_t[step_idx]\n",
    "    with torch.no_grad():\n",
    "        for step_idx in range(len(delta_t)):\n",
    "            preds = baselines[step_idx](curr_x)\n",
    "            curr_y = output_y.unsqueeze(0)[:, step_idx]\n",
    "            all_preds.append(preds[0][param_level_idx].detach().cpu().numpy())\n",
    "            all_truth.append(curr_y[0][param_level_idx].detach().cpu().numpy())\n",
    "\n",
    "    all_preds = np.array(all_preds)\n",
    "    all_truth = np.array(all_truth)\n",
    "    # NOTE(review): only the predictions are standardized here, the truth is left as-is; confirm this is intended\n",
    "    all_preds = (all_preds - all_preds.mean()) / all_preds.std()\n",
    "\n",
    "    ## Radial wavenumber grid and unit-width bins; identical for every lead time, so built once\n",
    "    ny, nx = all_preds.shape[1:]\n",
    "    kx = np.fft.fftfreq(nx) * nx\n",
    "    ky = np.fft.fftfreq(ny) * ny\n",
    "    kx, ky = np.meshgrid(kx, ky)\n",
    "    k = np.sqrt(kx**2 + ky**2)\n",
    "\n",
    "    k_bins = np.arange(0.5, np.max(k), 1)\n",
    "    k_bin_centers = 0.5 * (k_bins[:-1] + k_bins[1:])\n",
    "    k_bin_counts = np.histogram(k, bins=k_bins)[0]\n",
    "\n",
    "    # Compute power spectrum\n",
    "    curr_pred_Sk, curr_truth_Sk = list(), list()\n",
    "\n",
    "    for step_idx in range(all_preds.shape[0]):\n",
    "        pred_power_t = np.abs(np.fft.fft2(all_preds[step_idx]))**2\n",
    "        truth_power_t = np.abs(np.fft.fft2(all_truth[step_idx]))**2\n",
    "\n",
    "        # Mean power per wavenumber bin\n",
    "        pred_Sk = np.histogram(k, bins=k_bins, weights=pred_power_t)[0] / k_bin_counts\n",
    "        truth_Sk = np.histogram(k, bins=k_bins, weights=truth_power_t)[0] / k_bin_counts\n",
    "\n",
    "        curr_pred_Sk.append(pred_Sk)\n",
    "        curr_truth_Sk.append(truth_Sk)\n",
    "\n",
    "        # Plot against the bin centers: plotting against the array index puts the first bin at x=0,\n",
    "        # which a log axis cannot show, and shifts every other bin by one wavenumber\n",
    "        ax[i].loglog(k_bin_centers, pred_Sk, label=f'{delta_t[step_idx]}-day ahead', linewidth=3)\n",
    "\n",
    "    # Axis decoration is needed once per panel, not once per lead time\n",
    "    ax[i].set_title(f'{param}-{level}', fontsize=40)\n",
    "    ax[i].set_xlabel('Wavenumber, k')\n",
    "    ax[i].set_ylabel('Power, S(k)')\n",
    "    ax[i].set_ylim([10**0, 10**7])\n",
    "    ax[i].legend()\n",
    "\n",
    "    all_Sk[f'{model_spec}:{param}-{level}'] = np.array(curr_pred_Sk)\n",
    "    # Keep the truth spectra of every variable; a single 'truth' key is overwritten on each iteration\n",
    "    all_Sk[f'truth:{param}-{level}'] = np.array(curr_truth_Sk)\n",
    "    # Retained for backward compatibility: ends up holding the last variable only\n",
    "    all_Sk['truth'] = np.array(curr_truth_Sk)\n",
    "\n",
    "plt.show()\n",
    "f.savefig(f'../docs/specdiv_{model_name}_direct_Task {task_num}.pdf', dpi=200, bbox_inches='tight');"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6e47b4d6-9614-40a9-af7d-7e13c4d1ff87",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Plot full power spectra\n",
    "# Restore matplotlib's default rcParams before drawing the 3D panels\n",
    "# (presumably to undo the large font sizes configured at the top of the notebook -- confirm)\n",
    "import matplotlib as mpl\n",
    "mpl.rcParams.update(mpl.rcParamsDefault)\n",
    "\n",
    "def set_log_ticks_10_power(axis, num_ticks=5):\n",
    "    lims = axis.get_data_interval()\n",
    "    \n",
    "    ticks = np.linspace(lims[0], lims[1], num=num_ticks)\n",
    "    axis.set_ticks(ticks)\n",
    "    axis.set_ticklabels([f'$10^{{{int(tick)}}}$' for tick in ticks])\n",
    "\n",
    "\n",
    "eps = 1e-50 # keeps log10 away from an exact zero\n",
    "f = plt.figure(figsize=(16, 10))\n",
    "n_panels = len(param_levels)\n",
    "\n",
    "# One 3D panel per variable, filled left to right (subplot positions are 1-based)\n",
    "for subplot_idx, param_level in enumerate(param_levels, start=1):\n",
    "\n",
    "    curr_Sk = all_Sk[f'{model_spec}:{param_level[0]}-{param_level[1]}']\n",
    "\n",
    "    # Grid of (wavenumber, lead time); wavenumbers are numbered from 1 up to the number of spectral bins\n",
    "    wavenumbers = np.arange(1, curr_Sk.shape[1] + 1)\n",
    "    Wavenumber, Timestep = np.meshgrid(wavenumbers, delta_t)\n",
    "    log_wavenumber = np.log10(Wavenumber + eps)\n",
    "    log_power = np.log10(curr_Sk + eps)\n",
    "\n",
    "    # Draw the spectra as 3D contours in log-log space\n",
    "    ax = f.add_subplot(1, n_panels, subplot_idx, projection='3d')\n",
    "    ax.contour3D(log_wavenumber, Timestep, log_power, 100, cmap='inferno_r')\n",
    "\n",
    "    ax.set_xlabel('Wavenumber, k')\n",
    "    ax.set_ylabel('Number of days ahead')\n",
    "    ax.set_zlabel(r'Power, S(k)', labelpad=0.1)\n",
    "\n",
    "    # Both log-transformed axes get power-of-ten tick labels\n",
    "    for log_axis in (ax.xaxis, ax.zaxis):\n",
    "        set_log_ticks_10_power(log_axis)\n",
    "\n",
    "    ax.set_title(f'{model_spec}\\n {param_level[0]}{param_level[1]}', fontsize=12)\n",
    "\n",
    "plt.show()\n",
    "f.savefig(f'../docs/3d_{model_name}_sota.pdf', dpi=200, bbox_inches='tight');"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aa33258e-9870-46ec-94e3-38a14957c9be",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "bench",
   "language": "python",
   "name": "bench"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
